# -*- coding: UTF-8 -*-
# *******************************************************************
# File Name: img_to_json
# > Author: 04000387
# > Created Time: 2024/12/25 13:30
# *******************************************************************

from torch import nn


class ImgToJsonModel(nn.Module):
    """Image-to-JSON model: image encoder -> text decoder -> output head.

    The three components are held as attributes and applied one after another
    in :meth:`forward`.
    """

    def __init__(self, backbone, decoder, head):
        """
        :param backbone: image encoder; called with the image as its only argument
        :param decoder: text decoder; called with keyword arguments ``memory``,
            ``input_ids`` and ``attention_mask``
        :param head: output head; called with the decoder output
        """
        super().__init__()
        # When these are nn.Module instances, assignment registers them as
        # submodules; the order below is kept so parameter/state_dict naming
        # and ordering stay stable.
        self.backbone = backbone
        self.decoder = decoder
        self.head = head

    def forward(self, img, input_ids, attention_mask):
        """Run the encoder-decoder-head pipeline.

        :param img: image input forwarded unchanged to ``backbone``
            (shape/dtype depend on the backbone; not fixed here)
        :param input_ids: token ids forwarded to ``decoder``
        :param attention_mask: mask forwarded to ``decoder``
        :return: whatever ``head`` returns for the decoder output
        """
        encoded = self.backbone(img)
        # The backbone's output is handed to the decoder as its ``memory``.
        decoded = self.decoder(
            memory=encoded,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return self.head(decoded)
